source(here::here("R/00_setup.R"))

# Inputs read from the project's data/ directory: the precinct-level baseline
# table, the saved election model object, and the per-state control table.
d_baseline = read_csv(
  here("data/precinct_baseline.csv"),
  show_col_types=FALSE
)
model = read_rds(here("data/election_model.rds"))
d_control = read_csv(
  here("data/control.csv"),
  show_col_types=FALSE
)

# One row per state: log of total `ndv` minus log of total `nrv`, i.e. the
# statewide log ratio of the two baseline vote totals.
d_st_base = summarize(
  group_by(d_baseline, state),
  baseline = log(sum(ndv)) - log(sum(nrv))
)

# Gauss-Hermite integration rule (4 nodes), returned as a data frame.
# The node locations `z` are rescaled by the model's national standard
# deviation; the weights `w` are left untouched.
gh = lme4::GHrule(4, asMatrix=FALSE)
gh$z = gh$z * model$natl_sd
# Per-state seat totals for the enacted plan and for every simulated plan.
# For each state: take logit vote shares, shift them by each quadrature node,
# convert to win probabilities, average over nodes with the quadrature
# weights, and sum the probabilities within each plan.
d_seats = map_dfr(cli_progress_along(state.abb), function(i) {
  n_sim = 5000  # simulated plans per state; plan 1 of the n_sim + 1 is enacted

  abbr = state.abb[i]
  d_st = filter(d_baseline, state == abbr)
  plans = load_50state_plans(abbr)

  # logit vote shares (presumably one value per district per plan -- the
  # reshape below assumes n_sim + 1 plans in total)
  ldvs = qlogis(part_dvs(plans, d_st, ndv, nrv))

  # one column per quadrature node
  shifted = outer(ldvs, gh$z, `+`)
  p_win = rowWeightedMeans(model$prob_win_logit(shifted), w=gh$w)

  # one column per plan; column sums give the summed win probabilities
  seats = colSums(matrix(p_win, ncol=n_sim + 1))

  new_tibble(list(
    state = rep(abbr, n_sim),
    seats_enac = rep(seats[1], n_sim),
    seats = seats[-1],
    draw = 1:n_sim
  ))
}) |>
  left_join(d_control, by="state") |>
  left_join(d_st_base, by="state")

# National totals per draw: seats summed over all states. `diff` is a large
# constant rather than a real difference; the main figure orders its rows by
# `diff`, so this places the NATIONAL row last in that ordering.
d_natl = d_seats |>
  group_by(draw) |>
  summarize(
    label = "NATIONAL",
    seats_enac = sum(seats_enac),
    seats = sum(seats),
    diff = 1000,
    control = NA_character_
  )

# Split the per-state, per-draw gap (enacted minus simulated seats) by sign,
# total it within each draw, and take the median over draws. The result is a
# list with one element per label ("dem_tot" for positive gaps, "rep_tot"
# otherwise).
l_party = d_seats |>
  mutate(
    gap = seats_enac - seats,
    party = if_else(gap <= 0, "rep_tot", "dem_tot")
  ) |>
  group_by(party, draw) |>
  # The first summarize() relies on dplyr's default of dropping only the last
  # grouping variable (`draw`), so the second one still runs per `party`.
  summarize(gap = sum(gap)) |>
  summarize(gap = median(gap)) |>
  pivot_wider(names_from=party, values_from=gap) |>
  as.list()

# Export the headline statistics for the paper: the two party totals plus the
# national gap (median simulated national seats minus the enacted national
# seats).
append(
  l_party,
  list(natl_gap = median(d_natl$seats) - mean(d_natl$seats_enac))
) |>
  write_json(
    here("paper/stats_state.json"),
    pretty=TRUE,
    auto_unbox=TRUE,
    digits=6
  )

# Lookup from the raw `control` codes to the labels shown in the figures.
# The order of the values also fixes the factor level order used later.
recode_control = c(
  Dems = "Democrats",
  "Both Parties" = "Mixed",
  GOP = "Republicans",
  Independent = "Commission",
  Court = "Court"
)

# Colour for each display label; the two party colours are taken from the
# ends of ggredist's partisan palette (entries 15 and 1).
PAL_control = c(
  "Democrats" = ggredist$partisan[15],
  "Mixed" = "#dd55dd",
  "Republicans" = ggredist$partisan[1],
  "Commission" = "#77bb88",
  "Court" = "#ccbb33"
)

# Axis labeller for seat differences.
#   x: numeric vector of axis breaks.
# Returns a character vector: "No difference" at zero, "D+<n>" for negative
# values and "R+<n>" for positive values, with magnitudes formatted to an
# accuracy of 1.
label_seats = function(x) {
  dem_lab = number(-x, 1, prefix="D+")
  rep_lab = number(x, 1, prefix="R+")
  case_when(
    x > 0 ~ rep_lab,
    x < 0 ~ dem_lab,
    x == 0 ~ "No difference"
  )
}

# Lookup between the state abbreviations used in figure labels (`ap`) and
# two-letter postal codes (`state`), written row-wise so each pair can be
# checked at a glance.
ap_abbr = tribble(
  ~ap,      ~state,
  "Ala.",   "AL",
  "Alaska", "AK",
  "Ariz.",  "AZ",
  "Ark.",   "AR",
  "Calif.", "CA",
  "Colo.",  "CO",
  "Conn.",  "CT",
  "Del.",   "DE",
  "Fla.",   "FL",
  "Ga.",    "GA",
  "Hawaii", "HI",
  "Idaho",  "ID",
  "Ill.",   "IL",
  "Ind.",   "IN",
  "Iowa",   "IA",
  "Kan.",   "KS",
  "Ky.",    "KY",
  "La.",    "LA",
  "Md.",    "MD",
  "Maine",  "ME",
  "Mass.",  "MA",
  "Mich.",  "MI",
  "Minn.",  "MN",
  "Miss.",  "MS",
  "Mo.",    "MO",
  "Mont.",  "MT",
  "Neb.",   "NE",
  "Nev.",   "NV",
  "N.H.",   "NH",
  "N.J.",   "NJ",
  "N.M.",   "NM",
  "N.Y.",   "NY",
  "N.C.",   "NC",
  "N.D.",   "ND",
  "Ohio",   "OH",
  "Okla.",  "OK",
  "Ore.",   "OR",
  "Pa.",    "PA",
  "R.I.",   "RI",
  "S.C.",   "SC",
  "S.D.",   "SD",
  "Tenn.",  "TN",
  "Texas",  "TX",
  "Utah",   "UT",
  "Vt.",    "VT",
  "Va.",    "VA",
  "Wash.",  "WA",
  "W.Va.",  "WV",
  "Wis.",   "WI",
  "Wyo.",   "WY"
)

# Number of districts per state, read from the "ndists" attribute of each
# state's map object.
d_ndists = map_dfr(state.abb, function(abbr) {
  st_map = load_50state_map(abbr)
  tibble(
    state = abbr,
    ndists = attr(st_map, "ndists")
  )
})

# Main figure: per-state distribution of (simulated - enacted) seats, one row
# per state plus the NATIONAL row from `d_natl`, coloured by who controlled
# redistricting. The six states listed in the filter are excluded.
d_seats |>
  filter(!(state %in% c("AK", "DE", "ND", "SD", "VT", "WY"))) |>
  mutate(
    control = factor(recode_control[control], levels=recode_control),
    diff = seats_enac - seats
  ) |>
  left_join(ap_abbr, by="state") |>
  left_join(d_ndists, by="state") |>
  mutate(label = str_glue("{ap} ({ndists})")) |>
  bind_rows(d_natl) |>
  ggplot(aes(x = seats - seats_enac,
             y = reorder(label, diff),
             color = control)) +
  # --- layers (drawing order matters) ---
  geom_vline(xintercept=0, linewidth=0.4) +
  stat_slabinterval(.width=0.0, fatten_point=0) +
  # grey band from y = 44.5 upward, i.e. over the last row of the ordering
  annotate("rect", fill="#00000011",
           xmin=-Inf, xmax=Inf, ymin=44.5, ymax=Inf) +
  stat_slabinterval(interval_size_range=c(0.8, 2.0), fatten_point=1.0) +
  # --- coordinates, scales and labels ---
  coord_cartesian(xlim=c(-2.25, 5.5)) +
  scale_x_continuous(
    name="Difference of seats won under enacted plan\ncompared to non-partisan baseline",
    labels=label_seats,
    expand=c(0, 0)
  ) +
  scale_color_manual(
    name="Control of redistricting",
    values=PAL_control,
    na.value="black"
  ) +
  labs(y="← Favors Republicans          States, ordered by seat difference          Favors Democrats →") +
  # --- theme ---
  theme_bw(base_size=8, base_family="Arial") +
  theme(
    axis.text.y=element_text(face="bold", color="black"),
    legend.position=c(0.97, 0.95),
    legend.justification=c(1, 1),
    legend.background=element_blank(),
    legend.key=element_blank(),
    legend.key.height=unit(15, "pt"),
    legend.title=element_text(size=8, face="bold"),
    legend.text=element_text(size=8)
  )

# No `plot=` argument: ggsave() writes the most recently created plot.
ggsave(here("paper/figures/state_sum.pdf"),
       width=4, height=5.5, device=cairo_pdf)

# Count the states in which more than 97.5% of draws fall strictly on one
# side of the enacted plan's seat total (in either direction).
sum(pull(
  summarize(
    group_by(d_seats, state),
    bigger = mean(seats > seats_enac) > 0.975,
    smaller = mean(seats < seats_enac) > 0.975,
    reject = bigger | smaller
  ),
  reject
))


# summary data extract
# National totals per draw, shaped like the state rows built below so the two
# can be stacked with bind_rows().
d_natl = group_by(d_seats, draw) |>
  summarize(state = "NATIONAL",
            ap = "NATIONAL",
            ndists = 435,
            seats_enac = sum(seats_enac),
            seats = sum(seats),
            # evaluated after the two sums above, so this is the national
            # enacted-minus-simulated difference for the draw
            diff = seats_enac - seats,
            control = NA_character_)

# One row per state (plus NATIONAL): the enacted seat total and a point
# estimate with interval bounds for `diff`, computed with point_interval()'s
# default settings. The six states listed in the filter are excluded.
d_seats |>
  filter(!state %in% c("AK", "DE", "ND", "SD", "VT", "WY")) |>
  mutate(control = factor(recode_control[control], levels=recode_control),
         diff = seats_enac - seats) |>
  left_join(ap_abbr, by="state") |>
  left_join(d_ndists, by="state") |>
  bind_rows(d_natl) |>
  # seats_enac is constant within a state, so grouping on it simply carries
  # it through to the output under the name `dseats_enac`
  group_by(state, ap, ndists, dseats_enac=seats_enac, control) |>
  summarize(point_interval(diff),
            .groups="drop") |>
  rename(diff = y,
         diff_low = ymin,
         diff_high = ymax) |>
  select(-.width, -.point, -.interval) |>
  # Written inside the project, like every other output of this script.
  # The previous target was a user-specific ~/Desktop path, which fails on
  # machines without that directory and stops the script at this point.
  write_csv(here("paper/state_sum.csv"))


# National totals per draw for the appendix figures.
# `z`, `share` and `diff` are large constants, not real values: the figures
# order their rows by these columns, so NATIONAL sorts last in each ordering.
d_natl = group_by(d_seats, draw) |>
  summarize(label = "NATIONAL",
            seats_enac = sum(seats_enac),
            seats = sum(seats),
            z = 1000,
            share = 1000,
            diff = 1000,
            ndists = 435,
            control = NA_character_) |>
  # `scale` must be the standard deviation of the national seat total ACROSS
  # draws, matching how the state rows are scaled (sd of `seats` over draws).
  # Computing sd(seats) inside the per-draw summarize() above would instead
  # give the spread of seat counts across states within a single draw.
  mutate(scale = sd(seats))

# Axis labeller for z-scores.
#   x: numeric vector of axis breaks.
# Returns a character vector: "0" at zero, "D: <n>" for negative values and
# "R: <n>" for positive values, with magnitudes formatted to an accuracy of 1.
label_z = function(x) {
  dem_lab = number(-x, 1, prefix="D: ")
  rep_lab = number(x, 1, prefix="R: ")
  case_when(
    x > 0 ~ rep_lab,
    x < 0 ~ dem_lab,
    x == 0 ~ "0"
  )
}

# appendix: normalized to z scores
# Per-state rows with the enacted-minus-simulated seat gap expressed two
# ways -- divided by the state's district count (`share`) and by the
# across-draw standard deviation of simulated seats (`z`) -- followed by the
# NATIONAL rows from `d_natl`. The six states listed in the filter are
# excluded.
d_app = d_seats |>
  filter(!(state %in% c("AK", "DE", "ND", "SD", "VT", "WY"))) |>
  group_by(state) |>
  mutate(
    loc = mean(seats),
    scale = sd(seats)
  ) |>
  ungroup() |>
  left_join(d_ndists, by="state") |>
  mutate(
    control = factor(recode_control[control], levels=recode_control),
    share = (seats_enac - seats) / ndists,
    z = (seats_enac - seats) / scale
  ) |>
  left_join(ap_abbr, by="state") |>
  mutate(label = str_glue("{ap} ({ndists})")) |>
  bind_rows(d_natl)

# Appendix panel 1: seat difference divided by `scale` (z-score), rows ordered
# by `z`. The colour legend is suppressed in this panel.
p1 = ggplot(d_app, aes(x = (seats - seats_enac) / scale,
                       y = reorder(label, z),
                       color = control)) +
  # --- layers (drawing order matters) ---
  geom_vline(xintercept=0, linewidth=0.4) +
  stat_slabinterval(.width=0.0, alpha=0.0, fatten_point=0) +
  # grey band from y = 44.5 upward, i.e. over the last row of the ordering
  annotate("rect", fill="#00000011",
           xmin=-Inf, xmax=Inf, ymin=44.5, ymax=Inf) +
  stat_slabinterval(interval_size_range=c(0.8, 2.0), fatten_point=1.0,
                    show_slab=FALSE) +
  # --- coordinates, scales and labels ---
  coord_cartesian(xlim=c(-12, 10)) +
  scale_x_continuous(
    name="Z-score of seats won under enacted plan\ncompared to non-partisan baseline",
    labels=label_z,
    expand=c(0, 0)
  ) +
  scale_color_manual(
    name="Control of redistricting",
    values=PAL_control,
    na.value="black"
  ) +
  guides(color="none") +
  labs(y="← Favors Republicans          States, ordered by z-score          Favors Democrats →") +
  # --- theme ---
  theme_bw(base_size=8, base_family="Arial") +
  theme(
    axis.text.y=element_text(face="bold", color="black"),
    plot.margin=margin(r=10)
  )

# Appendix panel 2: seat difference divided by the number of districts (seat
# share), rows ordered by `share`. This panel carries the colour legend.
p2 = ggplot(d_app, aes(x = (seats - seats_enac) / ndists,
                       y = reorder(label, share),
                       color = control)) +
  # --- layers (drawing order matters) ---
  geom_vline(xintercept=0, linewidth=0.4) +
  stat_slabinterval(.width=0.0, alpha=0.0, fatten_point=0) +
  # grey band from y = 44.5 upward, i.e. over the last row of the ordering
  annotate("rect", fill="#00000011",
           xmin=-Inf, xmax=Inf, ymin=44.5, ymax=Inf) +
  stat_slabinterval(interval_size_range=c(0.8, 2.0), fatten_point=1.0,
                    show_slab=FALSE) +
  # --- coordinates, scales and labels ---
  coord_cartesian(xlim=c(-0.18, 0.21)) +
  scale_x_continuous(
    name="Difference in share of seats won under enacted plan\ncompared to non-partisan baseline",
    labels=label_party_pct(midpoint=0, reverse=TRUE),
    expand=c(0, 0)
  ) +
  scale_color_manual(
    name="Control of redistricting",
    values=PAL_control,
    na.value="black"
  ) +
  labs(y="← Favors Republicans          States, ordered by seat share difference          Favors Democrats →") +
  # --- theme ---
  theme_bw(base_size=8, base_family="Arial") +
  theme(
    axis.text.y=element_text(face="bold", color="black"),
    legend.position=c(0.995, 0.95),
    legend.justification=c(1, 1),
    legend.background=element_blank(),
    legend.key=element_blank(),
    legend.key.height=unit(15, "pt"),
    legend.title=element_text(size=8, face="bold"),
    legend.text=element_text(size=8)
  )

# Both panels side by side in a single PDF.
ggsave(here("paper/figures/state_sum_norm.pdf"),
       plot=p1+p2, width=7, height=5, device=cairo_pdf)
